# Proximity Analysis Helper Functions
# Alex F. Gazmararian
# agazmararian@gmail.com

# Declare column names that are referenced through non-standard evaluation
# (dplyr verbs, ggplot2::aes(), subset()) so R CMD check does not report them
# as undefined global variables.
utils::globalVariables(c(
  "estimate", "conf.low", "conf.high", "model", "term", "d_mfg_open_q"
))

# Required packages (must be installed; calls below use pkg:: so they need not be attached)
# library(dplyr)
# library(ggplot2) 
# library(modelsummary)
# library(tinytable)
# library(fixest)
# library(stringr)
# library(broom)
# Note: Functions like est_model should be available from the calling environment
# The pipe operator %>% comes from dplyr/magrittr which should already be loaded

#' Estimate models across multiple outcomes for proximity analysis
#'
#' For every outcome, fits one model per treatment variable by calling the
#' user-supplied `est_model()` function, which is defined outside this file and
#' must be visible from the global environment. The function is looked up once,
#' before any model is fitted, so a missing `est_model()` fails immediately.
#'
#' @param var.list Character vector of treatment variables
#' @param data.in Input dataset
#' @param outcomes Character vector of outcome variables (default: c("greenproj_bin", "credit_biden_bin", "greenbenefit_bin"))
#' @param cutoff Distance cutoff for Conley SEs (default: 50)
#' @param se Standard error type (default: "conley")
#' @param weights Survey weights (optional)
#' @param conley.distance Distance metric for Conley SEs (default: "triangular")
#' @param fe Fixed effects specification (default: "| state")
#' @param model_name_suffix Optional suffix (default: ""). When non-empty, each
#'   model gets the explicit name `<outcome>_<var>_<se>_<cutoff><suffix>`, which
#'   keeps results of subset/robustness runs distinct. When empty or NULL,
#'   `model_name = NULL` is passed and `est_model()` picks the name.
#' @param ... Additional arguments passed to est_model
#' @return Named list of model results organized by outcome then treatment variable
estimate_proximity_models <- function(var.list, data.in,
                                      outcomes = c("greenproj_bin", "credit_biden_bin", "greenbenefit_bin"),
                                      cutoff = 50, se = "conley", weights = NULL,
                                      conley.distance = "triangular", fe = "| state",
                                      model_name_suffix = "", ...) {
  # Resolve est_model once, from the same place the availability check looks,
  # rather than re-checking it for every treatment variable.
  est_fun <- get0("est_model", envir = .GlobalEnv, mode = "function")
  if (is.null(est_fun)) {
    stop("est_model function not found. Please ensure it is loaded.")
  }
  # NULL and "" both mean "let est_model choose the model name".
  use_suffix <- !is.null(model_name_suffix) && nzchar(model_name_suffix)

  result <- list()
  for (outcome in outcomes) {
    message("Estimating models for outcome: ", outcome)

    models <- lapply(var.list, function(var) {
      model_name <- if (use_suffix) {
        paste0(outcome, "_", var, "_", se, "_", cutoff, model_name_suffix)
      } else {
        NULL  # Let est_model generate default name
      }

      est_fun(var, outcome = outcome, data.in = data.in,
              cutoff = cutoff, se = se, weights = weights,
              conley.distance = conley.distance, fe = fe,
              model_name = model_name, ...)
    })
    names(models) <- var.list

    result[[outcome]] <- models
  }

  result
}

#' Generate proximity effect plot
#'
#' Plots the distance-quintile coefficients (terms starting with `d_`) from the
#' supplied models. Terms containing `_re_` are labelled "Renewable Energy" and
#' all other terms "Manufacturing"; quintile labels (Q1-Q4) are extracted from
#' the term names.
#'
#' @param models List of models for a single outcome (passed to modelsummary::modelplot)
#' @param title Plot title
#' @param y_label Y-axis label
#' @param data.in Dataset for sample size calculation
#' @param outcome Outcome variable name; the annotated N is the number of rows
#'   of `data.in` with a non-missing value of this column
#' @param legend.position Legend position: a numeric c(x, y) places the legend
#'   inside the panel; anything else is used as a ggplot2 legend position
#'   (default: c(.15, .1))
#' @param y_limits Y-axis display limits (default: c(-.15, .15)). Applied as a
#'   zoom via coord_cartesian(), so estimates whose confidence intervals extend
#'   past the limits are clipped at the panel edge instead of being dropped.
#'   NULL shows the full data range.
#' @return ggplot object
create_proximity_plot <- function(models, title, y_label, data.in, outcome,
                                  legend.position = c(.15, .1), y_limits = c(-.15, .15)) {

  # A two-model input is rebuilt as an unnamed list; anything else is passed to
  # modelplot() unchanged. Either way, the `model` column that modelplot()
  # produces is overwritten below from the coefficient names.
  if (length(models) == 2) {
    model_list <- list(models[[1]], models[[2]])
  } else {
    model_list <- models
  }

  # N shown in the corner: rows of data.in with a non-missing outcome.
  n_obs <- sum(!is.na(data.in[[outcome]]))
  # Numeric positions place the legend inside the panel (ggplot2 >= 3.5 API).
  legend_inside <- is.numeric(legend.position)

  p <- model_list %>%
    modelsummary::modelplot(draw = FALSE) %>%
    # Treatment type is read from the coefficient name ("_re_" = renewables).
    dplyr::mutate(model = dplyr::case_when(grepl("_re_", term) ~ "Renewable Energy", TRUE ~ "Manufacturing")) %>%
    # Keep only the distance-quintile terms and relabel them by quintile.
    dplyr::filter(grepl("^d_", term)) %>%
    dplyr::mutate(term = stringr::str_extract(term, "Q[1-4]")) %>%
    ggplot2::ggplot(ggplot2::aes(x = term, y = estimate, ymin = conf.low, ymax = conf.high,
                                 color = model, shape = model, group = model)) +
    ggplot2::geom_hline(yintercept = 0, color = "grey") +
    ggplot2::geom_pointrange(size = 1, linewidth = 1, position = ggplot2::position_dodge(width = 0.5)) +
    ggplot2::theme_classic(base_size = 10) +
    ggplot2::scale_color_grey() +
    # Zoom with the coordinate system rather than scale limits. Scale limits
    # turn out-of-range values into NA, which makes geom_pointrange() silently
    # drop any estimate whose confidence interval crosses a limit.
    ggplot2::coord_cartesian(ylim = y_limits) +
    ggplot2::annotate("text", x = Inf, y = -Inf, label = paste0("N = ", n_obs),
                      hjust = 1.1, vjust = -0.5,
                      size = 3, color = "grey40") +
    ggplot2::labs(
      x = "Distance quintiles (reference: Q5 = farthest group)",
      y = y_label,
      title = title,
      color = NULL,
      shape = NULL
    ) +
    ggplot2::theme(
      legend.position.inside = if (legend_inside) legend.position else NULL,
      legend.position = if (legend_inside) "inside" else legend.position,
      legend.background = ggplot2::element_blank()
    )

  p
}

#' Generate standardized modelsummary table for proximity effects
#'
#' Builds a modelsummary table rendered with tinytable, adds column-group
#' headers and "Yes" indicator rows, applies LaTeX theming and writes the
#' result to `filename`.
#'
#' @param models Named list of models
#' @param title Table title
#' @param notes Table notes
#' @param column_labels Character vector of group labels. Only used when the
#'   number of models is neither 6 nor 10 (those counts use fixed labels); each
#'   label then spans two consecutive model columns.
#' @param filename Output filename
#' @param label LaTeX label appended to the title as `\label{...}`; NULL or ""
#'   omits it
#' @param coef_map Coefficient map (default: uses global coefmap for distance terms)
#' @param resize_width Resize width (default: 0.45)
#' @param resize_direction Resize direction (default: "both")
#' @param fixed_effects_label Label for fixed effects row (default: "State Fixed Effects")
#' @return The tinytable object, invisibly (the table is also saved to `filename`)
create_proximity_table <- function(models, title, notes, column_labels, filename, label,
                                   coef_map = NULL, resize_width = 0.45, resize_direction = "both",
                                   fixed_effects_label = "State Fixed Effects") {

  # Default to a global `coefmap` when one is defined; modelsummary only shows
  # mapped coefficients that exist in the models. Without one, coef_map stays
  # NULL and modelsummary shows every coefficient.
  if (is.null(coef_map) && exists("coefmap", envir = .GlobalEnv)) {
    coef_map <- get("coefmap", envir = .GlobalEnv)
  }

  # Column-group headers. Table column 1 holds the row labels, so model columns
  # start at 2 and each group spans a pair of models.
  n_models <- length(models)
  if (n_models == 6) {
    group_list <- list("Visibility (=1)" = 2:3, "Credit Biden (=1)" = 4:5, "Benefits (=1)" = 6:7)
  } else if (n_models == 10) {
    group_list <- list("Governor (=1)" = 2:3, "State lawmakers (=1)" = 4:5, "Congress (=1)" = 6:7,
                       "Local officials (=1)" = 8:9, "Markets (=1)" = 10:11)
  } else {
    n_groups <- length(column_labels)
    if (2 * n_groups > n_models) {
      stop("create_proximity_table: ", n_groups, " column_labels need ", 2 * n_groups,
           " models (two per label) but ", n_models, " were supplied.", call. = FALSE)
    }
    # Built in one step so duplicated labels each keep their own span.
    group_list <- lapply(seq_len(n_groups), function(i) (2 * i):(2 * i + 1))
    if (n_groups > 0) {
      names(group_list) <- paste0(column_labels, " (=1)")
    }
  }

  # Prefer a global `gm` goodness-of-fit map; otherwise keep only adjusted R^2
  # and N, with LaTeX-formatted labels.
  if (exists("gm", envir = .GlobalEnv)) {
    gm_use <- get("gm", envir = .GlobalEnv)
  } else {
    gm_use <- modelsummary::gof_map
    gm_use$omit <- !(gm_use$raw %in% c("adj.r.squared", "nobs"))
    gm_use$clean[gm_use$raw == "adj.r.squared"] <- "Adjusted $R^2$"
    gm_use$clean[gm_use$raw == "nobs"] <- "$N$"
  }

  # Embed the LaTeX label in the caption so the table can be referenced.
  title_with_label <- if (is.null(label) || !nzchar(label)) {
    title
  } else {
    paste0(title, " \\label{", label, "}")
  }

  table_out <- modelsummary::modelsummary(
    models,
    title = title_with_label,
    notes = notes,
    coef_map = coef_map,
    stars = c("*" = 0.05, "**" = 0.01, "***" = 0.001),
    add_rows = data.frame(
      term = c("Covariates", "Sample Fixed Effects", fixed_effects_label),
      matrix("Yes", nrow = 3, ncol = n_models,
             dimnames = list(NULL, names(models)))
    ),
    gof_omit = "IC|RMSE|Std|FE|Within",
    gof_map = gm_use,
    output = "tinytable",
    escape = FALSE,
    fmt = modelsummary::fmt_significant(2)
  ) %>%
    tinytable::group_tt(j = group_list) %>%
    tinytable::theme_latex(resize_width = resize_width, resize_direction = resize_direction, placement = "H")

  # Save as a side effect, but return the table itself as documented.
  tinytable::save_tt(table_out, filename, overwrite = TRUE)

  invisible(table_out)
}

#' Run robustness analysis with different specifications
#'
#' For each named specification, optionally subsets the data and re-estimates
#' all outcome x treatment models via estimate_proximity_models(). Each model
#' name is tagged with `_<spec name>` so results from different specifications
#' stay distinct.
#'
#' @param var.list Character vector of treatment variables
#' @param data.in Input dataset
#' @param outcomes Character vector of outcomes
#' @param specs Named list of specifications. Each element is a list that may
#'   contain `cutoff` (default 50), `se` ("conley"), `conley.distance`
#'   ("triangular"), `weights` (NULL), `fe` ("| state") and `subset_condition`.
#'   `subset_condition` is a string R expression evaluated against the columns
#'   of `data.in`; every variable it references must be a column, and rows
#'   where it is NA are dropped. Unrecognised keys are ignored with a warning.
#' @return Named list of model results by specification
run_robustness_analysis <- function(var.list, data.in,
                                    outcomes = c("greenproj_bin", "credit_biden_bin", "greenbenefit_bin"),
                                    specs = list(
                                      "300km" = list(cutoff = 300),
                                      "500km" = list(cutoff = 500),
                                      "spherical" = list(conley.distance = "spherical"),
                                      "state_se" = list(se = "state"),
                                      "geo_subset" = list(subset_condition = "loc_anom == 0")
                                    )) {

  known_keys <- c("cutoff", "se", "conley.distance", "weights", "fe", "subset_condition")
  result <- list()

  for (spec_name in names(specs)) {
    message("Running robustness analysis: ", spec_name)
    spec_params <- specs[[spec_name]]

    unknown_keys <- setdiff(names(spec_params), known_keys)
    if (length(unknown_keys) > 0) {
      warning("Ignoring unrecognised parameter(s) in '", spec_name, "' spec: ",
              paste(unknown_keys, collapse = ", "), call. = FALSE)
    }

    data_use <- data.in
    cond_text <- spec_params[["subset_condition"]]
    if (!is.null(cond_text)) {
      cond_expr <- str2lang(cond_text)

      # Every variable the condition references must be a data column.
      absent <- setdiff(all.vars(cond_expr), names(data.in))
      if (length(absent) > 0) {
        stop("Variable '", paste(absent, collapse = "', '"), "' required for '", spec_name,
             "' robustness analysis not found in data.",
             "\n  For anonymized mode, ensure this variable is included in the FIPS cache.")
      }

      # Splice the parsed condition into a subset() call instead of using
      # eval(parse()). subset() evaluates it against the data columns, dispatches
      # on the data's class, and drops rows where the condition is NA.
      data_use <- eval(bquote(subset(data.in, .(cond_expr))))
    }

    # Always tag model names with the spec so cached vcovs stay distinct.
    # [[ ]] (not $) avoids partial matching of parameter names.
    result[[spec_name]] <- estimate_proximity_models(
      var.list = var.list,
      data.in = data_use,
      outcomes = outcomes,
      cutoff = spec_params[["cutoff"]] %||% 50,
      se = spec_params[["se"]] %||% "conley",
      conley.distance = spec_params[["conley.distance"]] %||% "triangular",
      weights = spec_params[["weights"]],
      fe = spec_params[["fe"]] %||% "| state",
      model_name_suffix = paste0("_", spec_name)
    )
  }

  result
}

#' Check within-state variation for an outcome
#'
#' Measures how much the outcome varies within states: the SD of residuals from
#' an intercept-only fixest::feols model with state fixed effects. Each
#' treatment's first coefficient is then expressed relative to that SD. When
#' `d_mfg_open_q` is among the treatments, the exercise is repeated against the
#' within-state SD of the Q1-vs-Q5 manufacturing indicator. Results are
#' reported with message() and written to one-line text files.
#'
#' @param outcome Outcome variable name
#' @param data.in Input dataset (must contain `outcome` and `state` columns)
#' @param models Models keyed by treatment variable name; the first row of
#'   broom::tidy() is taken as that treatment's coefficient
#' @param var.list Treatment variable names
#' @param stats_dir Directory the text files are written to
#'   (default: "output/pnas/stats"); created if it does not exist
#' @return Invisibly, a list with `outcome_sd`, `treatment_sd` (NULL when
#'   `d_mfg_open_q` is not in `var.list`) and `coefficients` (named numeric,
#'   one per treatment that has a model). Called mainly for its messages and
#'   files.
check_within_state_variation <- function(outcome, data.in, models, var.list,
                                         stats_dir = "output/pnas/stats") {
  message("Checking within-state variation in the ", outcome, " outcome")
  dir.create(stats_dir, recursive = TRUE, showWarnings = FALSE)

  # Path of a one-line stats file inside stats_dir.
  stats_path <- function(stem) file.path(stats_dir, paste0(stem, ".txt"))
  # Write a ratio as a whole-number percentage.
  write_pct <- function(ratio, stem) writeLines(format(round(ratio * 100)), stats_path(stem))

  # Within-state variation in the outcome: SD of residuals after state FE.
  y.resid <- fixest::feols(as.formula(paste(outcome, "~ 1 | state")), data = data.in) %>%
    resid() %>%
    sd()
  message("Within-state variation in the ", outcome, " outcome: ", y.resid)
  writeLines(format(y.resid, digits = 2), stats_path(paste0("within_state_variation_", outcome)))

  # Treatments that have a fitted model, and the first tidy() estimate of each,
  # extracted once and reused by both comparisons below.
  fitted_vars <- var.list[var.list %in% names(models)]
  coefs <- vapply(fitted_vars, function(v) broom::tidy(models[[v]])[["estimate"]][1], numeric(1))

  # Share of within-state outcome variation explained by each treatment.
  for (i in seq_along(fitted_vars)) {
    share <- coefs[[i]] / y.resid
    var_label <- if (grepl("_re_", fitted_vars[i])) "renewable" else "manufacturing"
    message("Share of within-state variation in ", outcome, " (", var_label, "): ", share)
    write_pct(share, paste0("share_of_within_state_variation_", outcome, "_", var_label))
  }

  # Within-state variation in treatment (manufacturing Q1 vs Q5 as example).
  treatment_sd <- NULL
  if ("d_mfg_open_q" %in% var.list) {
    message("Checking within-state variation in treatment")
    treatment_sd <- fixest::feols(I(d_mfg_open_q == "Q1") ~ 1 | state,
                                  data = subset(data.in, d_mfg_open_q %in% c("Q1", "Q5"))) %>%
      resid() %>%
      sd()
    message("Within-state variation in treatment: ", treatment_sd)
    writeLines(format(treatment_sd, digits = 2), stats_path("within_state_variation_treatment"))

    # NOTE(review): every treatment with a model (renewable included) is scaled
    # by this manufacturing-based SD; confirm that is intended.
    for (i in seq_along(fitted_vars)) {
      ratio <- coefs[[i]] / treatment_sd
      var_label <- if (grepl("_re_", fitted_vars[i])) "Renewable" else "Manufacturing"
      message(var_label, " effect size / within-state variation in treatment: ", ratio)
      write_pct(ratio, paste0("share_of_within_state_variation_treatment_", var_label))
    }
  }

  invisible(list(outcome_sd = y.resid, treatment_sd = treatment_sd, coefficients = coefs))
}

# Null-coalescing helper: hands back `x` untouched unless it is NULL, in which
# case the fallback `y` is returned (and only then evaluated). NA, FALSE and
# zero-length vectors are not NULL, so they are returned as-is.
`%||%` <- function(x, y) {
  if (!is.null(x)) {
    return(x)
  }
  y
}
